{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reinforcement Learning Tutorial for POMDPs.jl\n",
    "\n",
    "*inspired from Stanford Reinforcement Learning class*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "using POMDPs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Problem overview"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mars Exploration\n",
    "\n",
    "\n",
    "Let's consider a 1D gridworld and 2 different rewards at each extremities. \n",
    "\n",
    "![wall-e-mdp](initial_state.png)\n",
    "\n",
    "The environment can be modeled as follow\n",
    "\n",
    "- 10 states: 1,2,3,4,5,6,7,8,9,10\n",
    "- 2 actions: left, right\n",
    "- one episode lasts until a reward is found\n",
    "- there is a reward of +1 in state 1 (Eve) and +2 in state 10 (the plant)\n",
    "- Wall-E starts at 3\n",
    "\n",
    "# Implementation\n",
    "\n",
    "Let's start by defining a type containing all the problem parameters. In Reinforcement Learning, problems are usually framed as Markov decision processes (MDP). Our new problem must then be defined as a subtype of the abstract type `MDP{S,A}` where `S`and `A` are the state type and the action type respectively.\n",
    "For more information about MDPs please refer to this [MDP tutorial](http://nbviewer.jupyter.org/github/sisl/POMDPs.jl/blob/master/examples/GridWorld.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "mutable struct MarsExp <: MDP{Int64, Symbol}\n",
    "    r_left::Float64\n",
    "    r_right::Float64\n",
    "    start::Int64\n",
    "    γ::Float64\n",
    "    MarsExp(;r_left::Float64 = 1., r_right::Float64 = 10.,start::Int64 = 5, γ::Float64 = 0.5) = new(r_left, r_right, start, γ)\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**State Space**\n",
    "\n",
    "In this tutorial we consider a discrete environment. We can enumerate all the states and provide an indexing function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "function POMDPs.states(mdp::MarsExp)\n",
    "    return 1:1:10\n",
    "end\n",
    "POMDPs.state_index(mdp::MarsExp, s::Int64) = s\n",
    "POMDPs.n_states(mdp::MarsExp) = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Action Space**\n",
    "\n",
    "The action space is discrete and consists of only two actions: moving left or right. Again, we can enumerate and index the actions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "function POMDPs.actions(mdp::MarsExp)\n",
    "    return [:left, :right]\n",
    "end\n",
    "POMDPs.action_index(mdp::MarsExp, a::Symbol) = a == :left ? 1 : 2\n",
    "POMDPs.n_actions(mdp::MarsExp) = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Generative Model**\n",
    "\n",
    "In Reinforcement Learning problems, we do not assume prior knowledge on the environment. The agent learns the policy as it interacts with the environment. As a consequence all we need to provide is a generative model that allows to transition from one state to another. In addition we must define a reward function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "function POMDPs.generate_s(mdp::MarsExp, s::Int64, a::Symbol, rng::AbstractRNG)\n",
    "    if a == :left\n",
    "        return max(1, s-1)\n",
    "    elseif a == :right\n",
    "        return min(10, s+1)\n",
    "    end\n",
    "end\n",
    "function POMDPs.reward(mdp::MarsExp, s::Int64, a::Symbol, sp::Int64)\n",
    "    if sp == 1\n",
    "        return mdp.r_left\n",
    "    elseif sp == 10\n",
    "        return mdp.r_right\n",
    "    else\n",
    "        return 0.0\n",
    "    end\n",
    "end     "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** Instead of defining `generate_s`and `reward`, you can also define the function `generate_sr` that returns a tuple (next state, reward)."
   ]
  },
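  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of that alternative, the cell below folds the transition and the reward into a single `generate_sr` method. It assumes the `MarsExp` type and `using POMDPs` from the cells above, and takes the argument order from the requirements listing printed further down in this notebook. It is optional, since `generate_s` and `reward` are already defined."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: one method returning (next state, reward) instead of generate_s + reward.\n",
    "function POMDPs.generate_sr(mdp::MarsExp, s::Int64, a::Symbol, rng::AbstractRNG)\n",
    "    sp = a == :left ? max(1, s-1) : min(10, s+1)   # deterministic move, clipped to the grid\n",
    "    r = sp == 1 ? mdp.r_left : (sp == 10 ? mdp.r_right : 0.0)   # reward depends only on the next state\n",
    "    return sp, r\n",
    "end"
   ]
  },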
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Initial states**\n",
    "\n",
    "In the proposed problem, the agent always arrives in the world at the state specified in the MDP type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "function POMDPs.initial_state(mdp::MarsExp, rng::AbstractRNG)\n",
    "    return mdp.start \n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Terminal states**\n",
    "\n",
    "Once a reward is found, the game is over."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "function POMDPs.isterminal(mdp::MarsExp, s::Int64)\n",
    "    return s == 1 || s == 10\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Discount Factor**\n",
    "\n",
    "The reward received is discounted by a multiplicative factor. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "POMDPs.discount(mdp::MarsExp) = mdp.γ"
   ]
  },
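  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the effect of the discount concrete, here is a small self-contained sketch. The helper `discounted_return` is hypothetical and not part of POMDPs.jl. It sums rewards weighted by successive powers of `γ` for the trajectory that starts in cell 4 and always moves right, assuming `γ = 0.9` and a reward of 10 in cell 10."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Discounted return of a reward sequence: sum of γ^t * r_t with t starting at 0.\n",
    "function discounted_return(rewards, γ)\n",
    "    total = 0.0\n",
    "    for (i, r) in enumerate(rewards)\n",
    "        total += γ^(i - 1) * r   # enumerate is 1-based, so the first reward is weighted by γ^0\n",
    "    end\n",
    "    return total\n",
    "end\n",
    "\n",
    "# From cell 4, moving right: five zero-reward steps (4→5, ..., 8→9), then +10 on reaching cell 10.\n",
    "discounted_return([0.0, 0.0, 0.0, 0.0, 0.0, 10.0], 0.9)   # 10 * 0.9^5 ≈ 5.9049"
   ]
  },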
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize the problem\n",
    "\n",
    "Let's generate an instance of the exploration problem. Here we set the start state on cell 4. Feel free to change this value as well as the rewards and the discount factor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MarsExp(1.0, 10.0, 4, 0.9)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mdp = MarsExp(start=4, γ=0.9) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Solve with Q-learning\n",
    "\n",
    "To solve the problem we can use the `TabularTDLearning` package which provides three model-free learning algorithms:\n",
    "- Q Learning\n",
    "- SARSA\n",
    "- SARSA $\\lambda$\n",
    "\n",
    "You can install the package by running:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Package already installed"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mCloning TabularTDLearning from https://github.com/JuliaPOMDP/TabularTDLearning.jl\n",
      "\u001b[39m"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "POMDPs.add(\"TabularTDLearning\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All the solvers requires to specify an exploration strategy. Some common strategies like epsilon greedy are implemented in th `POMDPToolbox` package."
   ]
  },
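  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, an epsilon-greedy rule can be sketched in a few lines of plain Julia. This is not the `POMDPToolbox` implementation: the function name `eps_greedy_action` and the convention that `Q` is a states-by-actions matrix are assumptions made for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# With probability ε pick a uniformly random action index, otherwise the greedy one.\n",
    "function eps_greedy_action(Q, s, ε)\n",
    "    if rand() < ε\n",
    "        return rand(1:size(Q, 2))      # explore\n",
    "    else\n",
    "        return findmax(Q[s, :])[2]     # exploit: index of the largest Q-value in row s\n",
    "    end\n",
    "end\n",
    "\n",
    "Q_demo = [0.0 1.0; 2.0 0.5]            # hypothetical table: 2 states, 2 actions\n",
    "eps_greedy_action(Q_demo, 1, 0.0)      # ε = 0 is purely greedy, so this returns 2"
   ]
  },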
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mRecompiling stale cache file C:\\Users\\Maxime\\.julia\\lib\\v0.6\\POMDPToolbox.ji for module POMDPToolbox.\n",
      "\u001b[39m"
     ]
    }
   ],
   "source": [
    "using TabularTDLearning, POMDPToolbox"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's check that we implemented all the functions required by the QLearning solver. This can be done via the `@requirements_info` macro. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "INFO: POMDPs.jl requirements for \u001b[34msolve(::QLearningSolver, ::Union{POMDPs.MDP,POMDPs.POMDP})\u001b[39m and dependencies. ([✔] = implemented correctly; [X] = missing)\n",
      "\n",
      "For \u001b[34msolve(::QLearningSolver, ::Union{POMDPs.MDP,POMDPs.POMDP})\u001b[39m:\n",
      "\u001b[32m  [✔] initial_state(::MarsExp, ::AbstractRNG)\u001b[39m\n",
      "\u001b[32m  [✔] generate_sr(::MarsExp, ::Int64, ::Symbol, ::AbstractRNG)\u001b[39m\n",
      "\u001b[32m  [✔] state_index(::MarsExp, ::Int64)\u001b[39m\n",
      "\u001b[32m  [✔] n_states(::MarsExp)\u001b[39m\n",
      "\u001b[32m  [✔] n_actions(::MarsExp)\u001b[39m\n",
      "\u001b[32m  [✔] action_index(::MarsExp, ::Symbol)\u001b[39m\n",
      "\u001b[32m  [✔] discount(::MarsExp)\u001b[39m\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "true"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@requirements_info QLearningSolver(MarsExp()) MarsExp()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Initialize the solver**\n",
    "\n",
    "Use the keywords arguments to set the desired hyperparameters. \n",
    "\n",
    "**Tip:** run `?QLearningSolver` to toggle the documentation and have a list of all the hyper-parameters and their default value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "solver = QLearningSolver(mdp, learning_rate=0.1, n_episodes=100, max_episode_length=50, eval_every=50, n_eval_traj=100, \n",
    "                         exp_policy = EpsGreedyPolicy(mdp, 0.9));"
   ]
  },
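  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what these hyperparameters control, here is a minimal sketch of tabular Q-learning on the same 1D world, written in plain Julia and independent of POMDPs.jl. It is not the `TabularTDLearning` implementation: the names `step_state`, `step_reward` and `q_learning_sketch` are hypothetical, actions are encoded as 1 (left) and 2 (right), and `ε` is assumed to be the probability of taking a random action."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical stand-ins for the generative model and the reward function.\n",
    "step_state(s, a) = a == 1 ? max(1, s - 1) : min(10, s + 1)   # action 1 = left, 2 = right\n",
    "step_reward(sp) = sp == 1 ? 1.0 : (sp == 10 ? 10.0 : 0.0)\n",
    "\n",
    "function q_learning_sketch(; α = 0.1, γ = 0.9, ε = 0.9, n_episodes = 100, max_steps = 50, start = 4)\n",
    "    Q = zeros(10, 2)                       # one row per state, one column per action\n",
    "    for ep in 1:n_episodes\n",
    "        s = start\n",
    "        for t in 1:max_steps\n",
    "            # epsilon-greedy exploration\n",
    "            a = rand() < ε ? rand(1:2) : findmax(Q[s, :])[2]\n",
    "            sp = step_state(s, a)\n",
    "            r = step_reward(sp)\n",
    "            terminal = (sp == 1 || sp == 10)\n",
    "            # terminal states have value zero, so the target is the reward alone\n",
    "            target = terminal ? r : r + γ * maximum(Q[sp, :])\n",
    "            Q[s, a] += α * (target - Q[s, a])   # temporal-difference update\n",
    "            s = sp\n",
    "            if terminal\n",
    "                break\n",
    "            end\n",
    "        end\n",
    "    end\n",
    "    return Q\n",
    "end\n",
    "\n",
    "Q_sketch = q_learning_sketch()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this sketch, the rows for states 1 and 10 stay at zero because an episode ends as soon as one of them is reached, so no action is ever taken from them."
   ]
  },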
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are now ready to solve for the optimal policy!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "On Iteration 50, Returns: 0.8100000000000012\n",
      "On Iteration 100, Returns: 5.904899999999999\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "POMDPToolbox.ValuePolicy{Any}(MarsExp(1.0, 10.0, 4, 0.9), [0.0 0.0; 0.997261 2.63477; … ; 5.45566 9.88027; 0.0 0.0], Any[:left, :right])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "policy = solve(solver, mdp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize the policy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We provided some rendering functions to visualize the policy of the exploration problem. \n",
    "\n",
    "\n",
    "*You need two packages for this: [Cairo](https://github.com/JuliaGraphics/Cairo.jl) and [Colors](https://github.com/JuliaGraphics/Colors.jl), they can be easily installed using `Pkg.add`.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "render_policy (generic function with 1 method)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "include(\"render.jl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The arrow represents the action to take and the number is the value associated to the state\n",
    "\n",
    "**Note:** The value of the terminal state is zero. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlgAAABkCAIAAADVI9l0AAAABmJLR0QA/wD/AP+gvaeTAAAeQ0lEQVR4nO3deVyTV74/8I8UEUR2ZNdIAUERV4pVWrEMLxd+1qHL6DjXW8da1GrHcbl2ymirXayd9o7662Yt4zh422Fw2ilVXygO1woWl4iipCAIiMEAAbNAIEBCxPtHAuQJWZ7EQp423/fLfwjnhK/b88k5z3nOGfHgwQMQQgghjsrJ3gUQQggh9kRBSAghxKFREBJCCHFoFISEEEIcGgUhIYQQh0ZBSAghxKFREBJCCHFoFISEEEIcGgUhIYQQh0ZBSAghxKFREBJCCHFoFISEEEIcGgUhIYQQh0ZBSAghxKFREBJCCHFoFISEEEIcGgUhIYQQh0ZBSAghxKFREBJCCHFoFISEEEIcGgUhIYQQh0ZBSAghxKFREBJCCHFozvYuwIh//vOfxcXFwcHB9i6EoampCQBVxQZVxR5VxR5VxR5nq0pMTPzVr35l70IMcTEIz549Wya4kfjEXHsXwnClhC/X3Jv8eJS9C2GouFTd3aN6dNY4exfCcPvq3Z77mrBp3PpPKLrR1KPRBMYG2LsQhubyFpVa4zPR196FMMhvydTq++4TvOxdCIPyTlu36r5TsLu9C2HobVL2dvbAc5S9C2FSqObGzeLgVVSlUlEQshIYGPiLwOTXd++0dyEMb+9+p6ZLsOIPv7R3IQzZf/pW1tm66Hfz7F0Iw+mPipTdXU+mJ9i7EIbzmXxlt2rGyqn2LoSh9IsyRUd35DMT7V0IQ803t9o7VIEpE+xdCENzwR11a/fI+CB7F8LQUyLubekEz9PehTAJFb9I4eJV1ImToUP3CAkhhDg0CkJCCCEOjYKQEEKIQ6MgJIQQ4tAoCAkhhDg0CkJCCCEOjYKQEEKIQ6MgJIQQ4tAoCAkhhDg0CkJCCCEOjYKQEEKIQ6MgJIQQ4tAoCAkhhDg0CkJCCCEO7WcXhDUff3TK3jVYqamg4ESdvYsYRFLELxLau4hB5BduXBHZu4hBFFcqyxvtXcQgnWV1wmZ7FzGIuqpBIrV3EYM8qJdo2uxdxGDyDnTZ4+f+BK+iD+nnFYQ1HydFVUYttncZVgpOCayP38u1LPSf5yde8DeuZaHPXG/Jc19zLQs9H/NsfTGfa1k4eqp7x6vFXMtCl2g31QelXMvCEeNHPciu4VwW+jij5N5wZ+FP8yr6kH4+QVh76pWkqO2xeR8vsncl1ovbmBOaGb/3k4IWe1eiL2JZZsA3C/52rEhm70r08Rbv9yt47utTFzh10QpJfNv78ov5xVfa7V2JvoDYbR6VrxaXlyntXYk+39DV7k0flDZU2WWwY4rHyFTXnuyannqVvSvR54opI1FyD3LN8Py8n/JV9KFw8bBg69V+tO6lbZ9fAh6PzX1lQy7LXjFLt7+yKHLIiqoTnPiotJ518xjU5C/PEK5asWVfSvCQFQVhbdHhKjHr5uEQXUz/VLx8wW/eSvAfuqpEwiv/Uydh3TwM4utbvpCkPbEkY5rP0FXV2Fj+1d1W1s0DcK/q9W/lqfHzNsUM4WnlzS3CPHEH6+bekIv+/F3H/Ni41eGjh64qqUxSKGGfIaOhkB3hdydEjns21GXoqmpr19xoe8C6uROUmrzK3smhLvP8RwxdVV3dEHVb00GNH5oR5IWoMUNVEsDRq+hw+XkEIaImAwDmxk5KezqKfa8h/fsLDwxJnRHCtnUzbqKSD0QHDmVNAM9nbEr0WLatpbiFulIg0m8oawLCvH2Twn3Ztm7FbYjKgEe9h7ImIMTDa844L7atFbiDlgpg/BCGIAAEurvPCnJn21qJBrRWA6Gse9jGz21UrP8otq270IxOIRDoNpQ1AV4uIyZ4sY40FWToFQM+rH8ftnFzhr8r69YadAIKYPSQX6u5eBUdLj+PIIxY9Pvv1P/v46So7V/d2lL4+wh716MVMCslgGXT
pkMn8/mR6SUZT4cPaUkAfCfNY5s4kqPnL5aGPXPmt/N4Q1oSAK+IuWwTR55Tcr0sKOXr5x4LG9KSAHiEPebBsqkiV1BVMXb2XxfGsv7sYyt3/6lsU63zTLGo2ifm/UTeEH++Atw8otmmmrq4VCb0DN4+w3+IP18Box4ZzzbVHgia1WL3kSsinVl/9rGVM3xYX3gb7kHhgvixGOLPDFy9ig6Tn889QkS+Ulj9AY699FGNvSuxVl3B/tzQXcORgtYQ8v+eF7B2OFLQGqIbJwv8lg1HClqjsbKoyHvBcKSgNZrrBHyPWcORgtaQNtwtc58wHClojTaJutbVZThS0BpdHZCMHJYU7PPTvYo+nJ/HiLBP5CuFxa/Yuwjrhae8z8HFyryETf+wdw2DhU37z0x71zBYSMySffauYbDA8Nk77V3DYH6hES/bu4bBvPxHpdm7hsHcxmDasP/Qn+hV9OH8jEaEhBBCiPUoCAkhhDg0CkJCCCEOjYKQEEKIQ6MgJIQQ4tAoCAkhhDg0jj8+UXv6/+/fu/nwRQB4/KUDr237/UJLz3la7GLDezLVCT7ZcDKfr33QJjJm1ZIt++IsborWdOjo/j8WVgIAYozso9ZyYuvhzKwaAEiIXLhtzUbWD+MbUZS3Jf3anMydy+aZbye7efTSv/dcqwOAsDk7nkx+IYK5j5qs6I3j3+SIAGBG2JwNS5exfhjfiAvn9m4pn75/4+K55tu11e4tuJArFgFAUFha/JIMXt8+asJTs09eN94rdtnl+Sz/HhW5+V99ds/gxei3VyY+ZrR5Y/Gis1XG3ylqwenZugcaGxuL/1tQVaF927EBqXHzNoVYu8dM55kywZf12k3dvNltimapi1Jy5HrtOXkrAPh4z4+KWx3AeqO1robXCk3sMxs84b2JpvYaUBeXVp1QGLzouzopNHrgndv/VXmHrwAAnqdvckwo64fx+6juF93VVCh7AcDdaXKgyzwPizvIPBCI1MXSXgCAE2Mftfaez26b2M3Tz2V92CNsq9KgWg6xGgA8XTDeh8WD83pd4IIgD0S5WtnAIk5eRTmG0yPC0+umLNX9WQO49JfNaZPW5T9kFxvek6Gu4NX4A30pCKCmMuvAusUFTWY7Xd26Zl1fCgKozMpe53f06sD3BZ8sztClIAB+Tf7yjFcP2bwBd+2x9Gts2t1849PPdSkIQHRxT/aeN2oZ7/PrT3UpCKBUdDH90w+P2rwBt/DUlnIWzdqupH9xTJeCAMSi3JOfpd+Q2/pTjVLUG6bgw2qsPPni2b4UBHCvJe/sV1srDdPALMmRk+f7Ig1A67ny86vKzO+/aqmLUvjOd1d1KQhA3nqOf/6duk5rqrKBqtn877tLcpCvS0EAQoXsCL/2X1b9u1JpcivVuhQEoOytuN2dKzG/p+j9ohvdfSkIoLeiofsz0X1rfqolGtxo7kssQKHGD82oNr+nKLML1BBLcaPDmgaWcfEqyj0cHhHWfLz3cwBrjld/vChSezjI9oufv/fR9oW/M7W7ncUuNrwng/LEhuOVQMy7m7esiwuGdnR4IJ+fvf/Q1PfXmRjDFRx9MwtISNr16QuzwoE6wYkNBzL5hTmHFs1aFwDd/mrAqs2H9sUFA00FBfuXZ1f+8fCJBbbsNXPzjeyLllsBQv6/cwDMXHsmdRIPEPI/XHCmLud80ZoI7VYykqPnL5YCy1fseCvCH5AU8f+efqZuz/GiJFv2mqnda2okx3ShtKAMmBq7bNf8iDBAJDz13MnrZd9fujBt8VwAvMWXNxqcDyPP+fqzfeLp+9kOB4H2tjvA5Pjn97HcGjsk8fTKROZL2jFl9Nva4WB7+X+XtAAB65PnpYV4Qjs6PFtVUVKUG7okjd0mbWVlV88BUeNnpU/1DwSaW4SZ/Mrq+tozEf4LTGyoZrFLWW1lNaNB+at8UXV5XVl47FQ2NbmFvpcUynxJO9rzXW1yOAh0qZoBXmT0y8Z301YXVzYJAV7whGUTPfwA
qUxyTNDEFzTE6g8Zzaq/1yMGgvxcngp7xAtoa+/Jvq0RN2jq/UeON9VFpK5gdNF8d7tHLNUIxj4SNwrwGLl+2khmjweCmu5ipXMq6+FggxwKwNMdE73hBnR145YUYin8Q2FqY3h5h2GXEikUbZCP0XWx2MAiLl5FuYi7I8LTH2y/CLyU97Fua/PIVwrz1gCXtn1g8qOHxS42vCeTsp4PrNr8/rq+udDwuI2frogBKqtMHfvWcuLPhUBkujYFAYTHPa3tklvWBAAtJbk1QNKuvvnV4JSULe9GAjX1tSbe0oyivM9zED7D8v5jNw+fqQPmZKZO0qYaL2FT5kxA9M1h7U+VVeSJgJlr39JNlvrPS/jNjjBAJLbhhMIL547lImxqkMWG8rtSANNXz4/Q/g7CeIv3xwKQ3DVx6JLoxsl9Ykx94nEL06362lsrgAmetu+N3VhZ9Nk9TI6Pe0zvDVOTl6T1zYWGhCT+V3wA0FLP8lQmtfB4PeATo00sAIEBvPRYb6CV32JiAKe02KVT3A4gbOlAg9ht4wF0iG09lUnacPeEArzIseYSq0slBALdTJwp0aUoUwCewdoUBODn6/9ynC8gK2c7KHzQ1g3AeWbYI9rt0Lw8Rqb6AehtM3X6hUpzTQq4j3xqoIvzU6FOQG9tu/FxZJtEXaxEUKizqWQ11AuJGnDRJRYAN1dM8wMAielBYWcPAIzX6zLFHQA6NWwbWMLJqygXcTYI5dU/AFizVP/z/+KnXwLwQ7WJgKi11MViA4sCNkoPH98Xx3gtPJAH4GazidnR5no+sGoJY2wXnvK+9PDxU9rbhAFPnzp8XPrCLFYFmFd7LP0alq/4TarFlrJ7twDMjNO/iTgveg6AW1IJAPjO+8fO/VWpk36EqoSntpQjbcmSlB/hvZjarrz5vQhBKbusOYypUSEHAsaz3U17EO34b+zs/+ofUIYknl754ibmNqMhnj4A7ijYzY72dFQD86MYm4IGhs/OWrJwp6nbhErruzykLsmxmk54Bi8ze3CStKsbGG3yWIkulRDgBXgyNhp1G8UDmrvUJvo8NPUDMTA5kLGPqJf/qPXT3NKMHrek0nzX0Av3kU+xP4ypFwrA0425KagzPK0IrSHAzasoF3F2alR28wIwN4Z5GkjkpLnAhcpqwNg8WI2lLhYb2KSgNB/ApEDj62XqmoVATHQg6go+2ZCdzweML5YZ0FRwdP8fa5Cw4lkrk+PmG9kXMXPtWxE4et5SW7m4FJjhzzyPySdoBlAquQcMPnpQUpT39z0izFiQbGEBjqHavSevI3ZZBg85JRYb+yRGhu0TXz9ybuL4vqnRI+VAUEyisb2QtfOoafHWbcDd0NYCRIe2F2/9tqoCAAKsOkHwSsXlCiA1zsL+2lfuVoH1uFOp7gC8Q93RXFeeWS6qBiwulmlWWuwyelqI95dy0fGywMC+qdHj9YBP0DSbjmSqutskBBJ4FjbLlig7AV//roaDfJkQAEYPPnTQ6HhRqFQBbM4mHDHe26lYqbkmcvLqmxq9JgXcTR400abqBZx8XNAm6fmuQSMGDBfLMGmnXg2Ckw2jJyUpTAehrxtuq1HfCte+mc96JeACX2e2DSz46VxF7Y2zQfhT0Tfz+bs449+vbawEYnIPL/3jwHbulVnZ67Iadw0aBTYd2rtO2yxhxa5TVh7PK+T/OwdzMlMnAexPuGVDcvRve/ZoF44uWPsPK4/nFd24kAvtDTxWC17Cpi3ZLz+5pfzYc/0ra4Km708xFnVtV46UA0Ep/2ndHUtFQyuAqtfP9r/SklfyVV7bwPpPc9rLs6uBsbOfNx+DLJv16VC3At786/lfDvwhtZ4rP3+ufVbWVON/4M3tlrsEhsdtaxf8uf7qq/3nQ/uEbZtu02EUXZKzTYBncJKFZcNqiRKA7Iig/5VOfk0VX8lYZcqXtD/rqzck71JZNd/u5e+S2q3Ok6qzpX0vuTunjjOZW23dvYBT7d2u4oE54d6Khu6K
bmMrQvvmUadZP2cg7mYu6dTA/ISA2xhM0eAHJUr6C3PBFJ+BYaXFBuRHwtmpUQDAlCjm54uIqCkP3cWG9zTt6taMTD5i3l1jflVLJb8GCSt2lRw+Lj18vGRzegKAwje3Cga3jEmIBAB+9puWVqIyyYr+cKbO2uHaRD/mRdZ37ETjDXU3HUvPfP5rvjUp23blze9F1t3Aa5MVSkWMV8SSQmOnxNs2HITqbtE9AAGpyc+fXvni6ZUv/jV59mQA1Wc+bLTcm91wUPTht5crELB+rlWnMrVWyxEVO+v9JQuzlix8PyEmCkD91SPmFg9b6qLsvNrO/LOTd1y16QYhy+Gg7hYgRifERb+XFPdeUtz2uGAegKY7unWhvp4JAJruHGxo16aYVCY5KLByLbKq9053L+MVZe8dCxOrvWIlgkJdVkxzWz/NbcWjI4MASNVFg27i2jgcdEYQACVudKALANDVjRtS830ADSQ9zFfUkGisacAC56+iXMDtIDScdNZOTz9cFxve07iWq1vXvJmlXTtj8Zm/pF2nUgYWy5zavBBAVulVZqPgdRnvn8o4Lt27691I8LPXGUtKoyRHj39TGvbMn6wcruluB/bT3jg05P/Cbzf947f7qzas3RGG0jPMRyzMkecUFJRZdwNPnlNwLFeMqU8s+3pjxuWNGV8vSZkKUe7J/8kxXCxTW1gOYHqStQtYR8XuW/ni6ZVLNg0sbIndlxwNIO+uyGxPAKKL1QCi55jJt3bRh1+cydOunbF2SDF+1s7wgZUvOxPCAJwTm/3kYa5L55nrV88ZJmXrOf7lM1ZnYXt5EwDfWItPkbr5v5wU915SxLO+uknOvrUw4Eu0mePxbJwvAGHNnQ8KBa8VCj4QNAmDfROsKOaB4K66wjDVeituqwSmFsvoSnFJ8x9YLJP2qDOAijaDJyju35ECcJ5g/XAwyg8AFG0oacD5BpRIoXCH+fVhDXKI1fD0QnwongxFvB88AbEUDRq2Ddjg9FWUK7gdhNzVcmJxxptZiFm1+dA+E5Oi+lbNYM6CBo439z8/YNa6NekJRpLSuLaGv+8Rhe9YOsQn6PpOemHpMzOAnKqbbJormk/uE4dtNTqraYrw0j4xELssc1r/qtHHMpdMB0T7Spn/8YS3cgHETrRirGmGh/dkNs0ahXkAonjGn7sH0F6+9dszeQhITX5+kzWDQa35QczPMe5joky0ZNWlpe5LudGkbP2y1srJc5mCDyDYk+XjDYbcRjH+ZfqGbo8LTtB9DhmdEDlh+7hRptZcG6HWFCuNplpv8T1zzwVO9mLOgrqMMJJS7b0VAPyc2C4W1eeKeD8E9d3lDPJC/BiYe2azG7fVgDumjTFcaHq7g10D8qPh7D1CX2N3X43ep+1n9IatfheLDdipE3wSfyAfiHl3r+WxYERIDFBpodFgASGTAD6rpp137tUB2PPplj2M1y+mv3MRM9caX/lpdF2M0RU0+nzHTgRKWVXVfVcmArDvi73ME2uvb/nkuvldYNLCmd/iTUzD9VxmG1GrxEjLIdaokANIHWc82Rsbi188WwUErP+l1WPBMS7egLH5X9MCPVh1MUzKgMD5EJ2z6ifpFoIiwd/mhbaG/Hz9n/X1f7b/a5lCCPDcTax1McYw1TycJgMVJhp7uToBvSa+ydCm6jXy5qy5uSLKFQOXkm4oAE+z19ggg21iXBEEiK1pYA6Hr6Icw9kRoU/UFACHj+sf3X7qxF8weHq6X4SlLhYbsKFNwciFOSxSEH0PV2SdPFGn/2pzPR9ICAnRvqHfmqWGdwQFl7LYFmQT7e3Aa4IivdeKqi6i/8Zh7bHod7YY3hGsFeQMaVUAgNw6Y4M/BnlxjciWeVEA7cWLvvirwZ4vjQ11FcBkL/MrPBV8YYvJeVFtCo6Nftv6FATg7jIGwLlqIWNgpOyoBqI8jC8cDXRn1cVwZrWl+ZzV1akrWzpZzYsCkDW8Vig42MC4XyeVtQ3kXJfkYKHgtVKJ/u2z
KokMGD3Vl82SUR3DKU3tSM4Er1FOACqaNYz5dfUDMRDkqn8FfFDf2mvbvCh6caMB5+/pbhBqybsBwN/sjmhig6cMuw1DzmIDczh7FeUczgYhFqWtAfCX1Kc+0i63rPk4KfUwgJfSFtrcxYb3NHB164F8YGFOBuu9QOOefTcSqMnccPSqNgvrBCcWH8gHYtKmBgNA3OOroL0jqMvCvgaDJlSNGz1t+v6qnfq/duwIAzAn09yDgJNSZgK4mP63Iu2CPSH/w/RrAOakaP8xR8Qth/aOoO5KKqwt+nX2RQDLo9k8XOgaOynj8kb9X+u3BgGYvn9jhsnhIG9iGoDyY+k3arX360TCK+knr8Nw8Ce7LQaC/G2Zv/LgpQIVJUUfNuqysLGxWLsvzLxQC0FYfw8Y6x1q5FuiD89WAdFvL0x8zLZRk3v4f/gA8srMMok22JpbhO/wRYB3gqmtQQMsdQkInA+g/uo7dQYNBg0TLVA1KwDPUaz6+HomAMKau/+S6bJQKms4VtM5kHNunlM9AUVToa6Buqqh9kgT4OkVw3IppIvTZABSda7kvjbY2to1ubc1MDOS83BOdAeUPd+JDLo4RTB2KO2VKwH3EdY+NQEATvB3AdQQ9eWWvAM/mH/UwdXk+hrdKNBiA8u4eRXlIM5OjQKLPz6+9vDSzy9ti3Lb1v/i2txPBx7kzN8wIu0vcz+4WfxKBMsult/TrNbbOVkAkL98jeEmCgkrDukeeBB84ncgf+BLBK9bk56bkckvfDO+UL/9lr4B5ax9mxdmHcjPOrAui/mGbO4+slV7LDr74owFO/qff5iXunb5tc9zRN8seOeb/lbLV/Tv0z3prRVzcrIv5mTv0R8Fzliw460f8WOf8NTsk9enPrE+U7egJiJjyfTck9fLvj/23Pd6zWKXZegP/tpktwH4+Vq3XlQn7Pn4gLySlryzX+XpvTo5ft7ASK6xeNHZKsM92Nrb7gDw9ho8IGysvJ4HAFWvf2G4NzfrjdxGL5gew/+uslr/UQcgKjZuYH+1lvJVfFFU7JN9z8tb7OK/OiHsHF9UXX71Vf1dXsfPWm3Vfu5dqmYA7qOMrxeVNbwmkOltqOaRFDmaX9PJF1TpT+zzIscl6nLOJZHne0IgYzYY/XSMpfWoAx6Z96hzxW2NuEGd3aD3sp/LvP6/wfaez25rgkJd+56XHxE3bmRtZY9Y/4kLICjUJU5/Olb1QAbA1aYgBEI9cFsKsZQxYntU/1GHbpyXwtML08boXojyg1gKRRtK9Meq7gPPYFhsYBEHr6KcxN0RIYBFh344fmDNHN1Xj790IPfmIQsfOix2seE9B8jarb/bByDg6VN7d72bFKP7MjJm1eZDjMcE4zaWbE5f1b9P3+AGQ2LSWxvW7pjZ99xH2JwdK5ghF7HszIpnlvenTVj48hU7rH2O0Gq8xZdXpqQNLGMIS3tiPetjJVgJiVny1+TZqf13QsdGr09mve+oMQ1tNu+Prsedt/OpWf8x3lv3pY/3/IQnLewRY7FLQGzWUzHzBxbtes+PfdLUg4k/Fr/QCL21MICn79NxzH1HGYtlwAsOXp0QkWjVk3EeI9fHjJw8sC2A0+RQVwtnRIxyTotxSfTru965O01+1NX4tjI2Yy6W8XTHlECEmh9puOLJwIEuAIK88KS3NQ0s49xVlJNGPHhgftN2O9i9e3cvNK/v3mnvQhje3v1OTZdgxR9+ae9CGLL/9K2ss3XR76zc7mWInf6oSNnd9WS6NYvih975TL6yWzVjJasNp4dN6Rdlio7uyGdMPMJpJzXf3GrvUAWmTLB3IQzNBXdkrd0j4y1vWjucekrEmpZO8Gz/KDUkhIrXV2/i4FXUCc67d++2dyGGOD0iJIQQQoYaBSEhhBCHRkFICCHEoXHxHuHLL798q+bWL1KS7V0Iw/8WnJX2iKcnsdqDZNhcL6zo6umaONf6A3yH0q0LdeqenvCEcfYuhKGOf1et0YTMGOpFSNZpLG1S
qTV+saa3MbAHafk9ler+mCgrjrcaBh3V8s6unkfCfrQn+38U90Xtvcoe+FixH8BwkKuS4xM5eBWdGDnx4MGD9i7EEBcfn0hOTlapVAo5y0NNhwlvHI8HXoDKqqXnQ27mox4AAh7hVlVjo3kAAsZwq6oJU1oABIzlVlUt01sABARwrCp3TlYVxMmqWrhbFQevosnJ3MpmLS6OCAkhhJBhQ/cICSGEODQKQkIIIQ6NgpAQQohDoyAkhBDi0CgICSGEODQKQkIIIQ6NgpAQQohDoyAkhBDi0CgICSGEODQKQkIIIQ6NgpAQQohDoyAkhBDi0CgICSGEODQKQkIIIQ6NgpAQQohDoyAkhBDi0CgICSGEODQKQkIIIQ6NgpAQQohDoyAkhBDi0CgICSGEODQKQkIIIQ6NgpAQQohD+z/f/GqqZgY2oAAAAABJRU5ErkJggg==",
      "text/plain": [
       "Cairo.CairoSurface{UInt32}(Ptr{Void} @0x000000000adf5cd0, 600.0, 100.0, #undef)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "render_policy(policy)"
   ]
  },
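  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A common convention, assumed here rather than taken from `render.jl`, is that the value of a state is the largest Q-value in its row and the arrow is the action that attains it. The sketch below applies this to a small hypothetical Q-table (`Q_example` is made up for illustration and is not the learned policy)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical Q-table: 3 states (rows), 2 actions (columns: left, right).\n",
    "Q_example = [0.0 0.0; 1.0 2.6; 5.5 9.9]\n",
    "action_names = [:left, :right]\n",
    "\n",
    "for s in 1:size(Q_example, 1)\n",
    "    v, a = findmax(Q_example[s, :])   # state value and index of the greedy action\n",
    "    println(\"state \", s, \": value = \", v, \", greedy action = \", action_names[a])\n",
    "end"
   ]
  },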
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate the policy\n",
    "Using the simulator from POMDPToolbox you can run many simulations to evaluate the learned policy. The `RolloutSimulator` type provides a fast simulator that will return the total discounted reward."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rng = MersenneTwister(9)\n",
    "sim = RolloutSimulator(rng = rng, max_steps=10);\n",
    "n_episodes = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performance of the Q learning policy: 5.90 ± 0.00"
     ]
    }
   ],
   "source": [
    "qlearning_rewards = zeros(n_episodes)\n",
    "for ep in 1:n_episodes\n",
    "    r_tot = simulate(sim, mdp, policy)\n",
    "    qlearning_rewards[ep] = r_tot\n",
    "end\n",
    "@printf(\"Performance of the Q learning policy: %2.2f ± %2.2f\", mean(qlearning_rewards), std(qlearning_rewards))    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's compare it to a random policy. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "random_policy = RandomPolicy(mdp);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performance of the random policy: 0.68 ± 1.42"
     ]
    }
   ],
   "source": [
    "random_rewards = zeros(n_episodes)\n",
    "for ep in 1:n_episodes\n",
    "    r_tot = simulate(sim, mdp, random_policy)\n",
    "    random_rewards[ep] = r_tot\n",
    "end\n",
    "@printf(\"Performance of the random policy: %2.2f ± %2.2f\", mean(random_rewards), std(random_rewards))    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Try the following experiments:**\n",
    "\n",
    "- Vary the discount factor and visualize the policy.\n",
    "- Change the exploration parameters in the epsilon greedy policy."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 0.6.0",
   "language": "julia",
   "name": "julia-0.6"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "0.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
